{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "4af2332a",
   "metadata": {},
   "source": [
    "## Spatio-Temporal Convolution\n",
    "\n",
    "The STConv module contains two temporal convolutions and one spatial graph convolution, applied in the following order: <br />\n",
    "\n",
    "TConv &#10141; SConv &#10141; ReLU &#10141; TConv &#10141; BatchNorm.<br />\n",
    "<br />"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6c060d97",
   "metadata": {},
   "source": [
    "### STConv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "7e473b3a",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 500/500 [17:27<00:00,  2.10s/it]\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAEWCAYAAABrDZDcAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAA2lklEQVR4nO3dd3hUZfbA8e9JrySBBAiETuhIMWJDBBVFsJe1rK5t15+9brGvZYt91V3XtXdUdC2oLFIEQamhQ2gBQkkIKaSQXub9/XFvhkkhGcpkkpnzeZ555s69d2bOjXjPvF2MMSillPJfAd4OQCmllHdpIlBKKT+niUAppfycJgKllPJzmgiUUsrPaSJQSik/p4lAqXZGROaLiBGR670di/INmghUmyEiGfYN7iJvx9IWiMh4+++R0eDQF8DLQFrrR6V8UZC3A1BKgYgEG2Oq3TnXGPMvT8ej/IuWCFS7ISIXi8hyETkgIjtF5FURibWPhYjImyKSLSKVIrJbRL61j4mI/M3eV2mf84OIdDrE90SKyHMisk1ESkRktYhcax/rKyIOEckXkWB7Xy/7l3u+HUeQiPxRRDaKSKmIpInIzS6f/7h9/hciMk1EyoFfN4hhPDDPfln3+cY+Vq9qSETes19/ICL/E5FyEZllx/VfO4bFItLH5fOHicj3IpIjIrn2eT2PwX8m1Q5pIlDtgohMBr4EjrOfDwC3AZ/ap/wG+C2QB7wNrABOsY+dCTwI1NrHFgDDgehDfN27wO/t86cBycAHInKVMWY78AvQETjLPv9X9vNnxpgq4CngGUCAqUAY8LqIXNfgey4F+gEfAtkNju0B/mtvH8CqCnr5EPHWuQYoAfYDE4E1QCywHTjJjgsR6Wr/DSYCPwPzgUuAH0QktIXvUD5IE4FqL+6wn/9mjLkOGA/UAOeIyAAg2D6+DvgYuAHobO+rO5aOdWO/A+gO7Gr4JSLSGbjcfjnRGHMj8JD9+k77+QP7+Qr7uS4RfCAi4hLrIqAUWG+/vrXB120HTjTG3GyMmel6wBiTDtRVAe03xtxjjLmnYbwN/GiMuRx4035djnWzr4t/lP18LRCH9ffYBWQCucAgYEIL36F8kLYRqPait/28EcAYkycieUBXoBfWzXk8cCFwJWCAOSJyMTAL+DfWDbCuuiUVuADYe4jvKTfG7LS3N9nPveznacArwEUiMghIAbYYY5aISAIQZZ93Q4PP7t/g9TJjTE1LF34YNtrPhfZzujHGISIH7NeR9nNv+3mw/WguRuUHtESg2osM+3kQgF2/H2/v2wnUGGOuADpg3dzmYP0avgQIxPqVHot1o/sA6+b922a+J9ylznygy/dgjCkCvgFigDfsY3WlhDysUgDACGOMGGME6/+1lAbfVdnCNdfaz+7+f1rbwus6GfbzV3Xx2TEmYlWdKT+jiUC1Rc+IyBKXxzjgVfvYQyLyHla9dhAw2xizBbhKRDZi1e/fjdUGANav41OAHVhVRvcBp7ocq8cYk4PVPRNgtoi8A/zNfu3aW6fuxn8aVunjQ/v9xiXWWXYD9idY1UCPH9ZfAXbbz0ki8paI/Okw338oH2Nd+8V2o/nrIjLH/r4ux+g7VDuiVUOqLRrQ4HVHY8zXIvIr4AHgMqwG0dexGoEBNmP9Gp+M1Qi8F/gL8B1WKWArVqNxrH3efzj4a76hG7FuihdjtQNsA140xkx1OecHrAbersBPxhjX9oZHgHzgeqwG3GJgFfCZm9cPgDEmQ0Sexyq53ARswGqEPirGmCwROR34KzAGGIvVVvAq1t9G+RnRhWmUUsq/adWQUkr5OU0ESinl5zQRKKWUn9NEoJRSfq7d9RqKj483vXv39nYYSinVrqxYsSLPGJPQ1LF2lwh69+5Namqqt8NQSql2RUR2HuqYVg0ppZSf00SglFJ+ThOBUkr5OU0ESinl5zQRKKWUn9NEoJRSfk4TgVJK+Tm/SQTLM/bzwqzN
VNc6vB2KUkq1KX6TCFbuLOCfP6ZrIlBKqQb8JhEEBggANQ5df0EppVz5XSJwaCJQSql6/C4RaIlAKaXq87tEoCUCpZSqz38SgWiJQCmlmuI/icAuEdRqIlBKqXr8JhEEBWoiUEqppvhNIgjQqiGllGqS3ySCoADrUh1GE4FSSrnym0QQaF9pTa0mAqWUcuVHiUBLBEop1RQ/SgTWs7YRKKVUfX6UCKxL1V5DSilVn/8kAtHuo0op1RT/SQQ6oEwppZqkiUAppfyc3yWCGocuTKOUUq78LhFo91GllKrPbxJBUF2JQAeUKaVUPX6TCOrmGtISgVJK1ec3iaBu9lEdUKaUUvX5TSLQXkNKKdU0/0kEOqBMKaWa5D+JQEsESinVJE0ESinl5/wmEdR1H63VXkNKKVWPRxOBiEwSkc0iki4iDzRxvKeIzBORVSKyVkQmeyqWAC0RKKVUkzyWCEQkEHgVOBcYAlwlIkManPYIMM0YMwq4Evi3p+LRAWVKKdU0T5YIxgDpxpjtxpgq4FPgwgbnGKCDvR0DZHkqmACdYkIppZrkyUTQHdjt8nqPvc/V48A1IrIHmAHc2dQHicjNIpIqIqm5ublHFIyzRKBVQ0opVY+3G4uvAt4zxiQBk4EPRaRRTMaYN4wxKcaYlISEhCP6ogAdR6CUUk3yZCLIBHq4vE6y97m6CZgGYIxZDIQB8Z4IJkgbi5VSqkmeTATLgWQR6SMiIViNwdMbnLMLOBNARAZjJYIjq/tpgY4jUEqppnksERhjaoA7gB+AjVi9gzaIyJMicoF92v3A70RkDfAJcL0xnmnNFRECRBOBUko1FOTJDzfGzMBqBHbd95jLdhpwqidjcBUYIDqgTCmlGvB2Y3GrCgwQLREopVQD/pUIRBOBUko15F+JQEsESinViF8lgqDAAE0ESinVgF8lggARHVmslFIN+FUiCAoQah0Ob4ehlFJtil8lAquNwNtRKKVU2+KHiUAzgVJKuWo2EYhIoIhcICIDWysgTwoPDqSgrNrbYSilVJvSbCIwxtQCbwMnt044nnVq/3h+2pLLzPV7vR2KUkq1Ge5UDX0MXC8iQ0WkY93D04F5wsWjrOUQHv1mg5cjUUqptsOdRHAXcBqwFmtm0Fwgx5NBecrwpBiuSOlBRVWtt0NRSqk2w51J5xZgLSnpExKiQymrrsUYg9iL1SillD9rMREYY8a3QhytJiI0kFqHobLGQVhwoLfDUUopr2uxakhEYkTkPRHZZz/eEZGY1gjOE6JCrdz35cqGi6UppZR/cqeN4BXgN0CV/bgeeMlzIXlWZIiVCB76ah3bc0sAuPKNxVzx+mJvhqWUUl7jThvBucCzxpgHAETkGeAGj0blQZGhBy+5bkzBku37vRWOUkp53ZGMLG7XDcdRLokg90ClFyNRSqm2wZ0SwQzgDyJytf26O/CB50LyrMjQgw3EOQcqvBiJUkq1De4kgnuwSg7n2q8/BO71VECe5loiyCnWEoFSSjWbCEQkEHgUeNcY85vWCcmzXNsI9hVriUAppdyZa+gioF+rRNMKXBNBVlE5Dl2oRinl59ypGpoPPCYioYBztjZjzJeeCsqTXKuGNmeXUF6t000opfybO4mgrqvoK/azYPUcapfDcgMDhIynp/DWwu385fuN7Cko93ZISinlVe4kgic8HoUXDOwaDcCKnQXOfTW1DoIC/WqtHqWUcquxuAPwnTFmXuuE1DpG9YwjMiSQqct2OvdVaSJQSvkhv2ssrhMVGsSFo7qzPrPYua+qRpexVEr5H79rLHY1aWhXpi7d5XxdqYlAKeWH/K6x2NWJfesvtFZZfTARLNqWR+foMPp3jmrtsJRSqlW5kwiepJ3PL3QooUGBnNy3E4u35wNQVXuwK+nVby4FIOPpKV6JTSmlWos7C9M83gpxeM0715/AQ1+t46tVmZz14gJWPzaRy/6jU1IrpfzHIRuLRWSliEwUkUh7MZpB9v6LRcRn5m0ODwnk
wpHdnK+f/C6N9JwSL0aklFKtq7leQyOBOCAMazGaurtlCNBuVyhrSkjQwT/D+swiL0ailFKtr6VO8z7ZNtBQaNDBdu/sIp2ITinlX1pqI7gRmIiVEO4QkYuAZE8H1dpCXUoExRU1XoxEKaVaX0uJ4ByX7Ytctt0qKYjIJOBlrK6mbxljnm7inF8Bj9ufucYYc3XDczytqlbHDyil/FdziWDC0XywPT3Fq1glij3AchGZboxJczknGXgQONUYUyAinY/mO49UUly4N75WKaXahEMmAmPMT0f52WOAdGPMdgAR+RS4EEhzOed3wKvGmAL7O3OO8juPSOfoMDKensKfvljLZ6m7vRGCUkp5jSdnWOsOuN5V99j7XA0ABojILyKyxK5KakREbhaRVBFJzc3N9VC4EB3mzvg6pZTyLd6eajMIq/F5PHAV8KaIxDY8yRjzhjEmxRiTkpCQ4LFgojQRKKX8kCcTQSbQw+V1kr3P1R5gujGm2hizA9iCF3slua5eppRS/uKQdz4Reay5Nxpjnmzhs5cDySLSBysBXAk07BH0NVZJ4F0RiceqKtrewud6jGs30joOhyEgQLwQjVJKtY7mfgI/7rJtsGYdrdsGazK6QzLG1IjIHcAPWN1H3zHGbBCRJ4FUY8x0+9jZIpIG1AJ/MMbkH/5lHBulVY3XL66qdRAW0O4nWlVKqUNqLhFcbj9PAE4H/oFVlXQ34NasbMaYGcCMBvsec9k2wH32w+uuPakXxsDzszZT67DyXWW1g7BgTQRKKd/VXPfR/wKIyPPAX40x79ivBfhj64TXuiJDg7h1fD96dozg9qkrAaisrQWCvRuYUkp5kDuto6HAn0UkCatEcAPe723kUVOOS6Ss6jj+8MXaeovVKKWUL3InEfweeAuoq9KpwJqDyKfVzUiqy1cqpXydOwvTTBWROcBJ9q4l3hoB3JrqZiStqG7cgKyUUr7E3SqeE4AzgHSsXj4jPBdS21DXlfTqN5d4ORKllPKsFhOBiNwDfAvcCXQFLgGe82xY3lc3EZ1OS62U8nXulAjuAT53eT0HGO2RaNqQ5C7R3D6hHwGCsyupUkr5IncSQRywxuV1BNYAMZ/XKTIUh4Hi8mpvh6KUUh7jTq+hZcCt9vbvgbHALx6LqA3pFBUCwP6yKuIiQ7wcjVJKeYY7JYI7gXKsKSYmAXuxqot8XlyEnQhKq7wciVJKeU6zJQJ7lbEBWA3EdR3qNxtj/KJPZUe7FPD6T9tI6RWHNahaKaV8S7MlAvuG/zaQYoxJsx9+kQQAZ3XQnI05LNux38vRKKWUZ7hTNfQxcL2IDBWRjnUPTwfWFnRyaRf4bLkuYamU8k3uNBbfhTX19FqXfcbN97ZrYcGBrH5sIg9/tZ6lWiJQSvkod27mCzi4BoHfiY0I4fhecXy/bi/ZRRV0jQnzdkhKKXVMuTPX0PhWiKNNG90rDoCVuwqYPDzRy9EopdSx1WIisNcfuBIYDtT9HDbGmPs9GVhbMiSxA6FBAazYqYlAKeV73KkaehW4hcbLVfpNIggJCmBItw68/fMOosOCuPakXnSKCvV2WEopdUy402voYmCqvX03MA94ymMRtVHXnNgLgJfmbOX+z9e0cLZSSrUf7s41tNDe3gt8AdzssYjaqEuPT+KXB84AYOXOAi9Ho5RSx447iSAbqwopG2ulshfcfJ/P6R4bzu9O60NljQNj/LYjlVLKx7hzQ38E2IbVJlABFOEncw01JTEmnMoaBwVlOiOpUso3uNN99COXl596MJZ2oVus1XFqb1G5cy4ipZRqz9zpPvpjE7uNMeZMD8TT5nWNsVYuyyqsYGi3GC9Ho5RSR8+d7qPjm9jntxXkfeIjAdiac4CJQ7p4ORqllDp67iSCBJftOOBxrN5DfikmPJikuHDSsoq9HYpSSh0T7jQWG5dHMbAZuM6TQbV1QxI7aCJQSvkMd0oEeTSuCtrsgVjajeQuUczZuI9ahyEwQBerUUq1b4c7
+2gtkAE876mA2oOuMeE4DOSVVNKlg85GqpRq33T20SPQ1b75ZxdVaCJQSrV77nQffaeZw8YYc9MxjKddqEsEe4sqGNHDy8EopdRRcqdq6Hoazzzquu13iaBLjDXz6L7iCi9HopRSR8+dRPA8MAZ4EquX0SPAcvx4lHF8ZChBAcLeIk0ESqn2z51E8BvgCWPMjwAiMgD4kzHmDx6NrA0LCBD6JUSxKVu7kCql2j93EkE58HcROQmrSugCIN+jUbUDI3rEMDttH8YYrEXclFKqfXJnQNlvsZLBtcA1QJm9r0UiMklENotIuog80Mx5l4qIEZEUdz63LRjRI5aCsmp27y/3dihKKXVU3Ok+OldEegGD7F2bjDFVLb1PRAKxlrmcCOwBlovIdGNMWoPzorFWPlt6uMF704ikWABW7S6gZ6cI7wajlFJHodkSgb1wPfaNPxHrpn66m589Bkg3xmy33/8pcGET5z0FPIO11kG7MbBrNGHBAazaVejtUJRS6qgcMhGIyFxgtr19EzADeBaYKSKPuPHZ3YHdLq/32Ptcv2M00MMY831zHyQiN4tIqoik5ubmuvHVnhccGMCIpFjeW5TBz1vzdMUypVS71VyJYBhQd4O+xX5+CvgJ+N3RfrGIBAAvYq181ixjzBvGmBRjTEpCQkJLp7ea5y8fQUx4MNe8vZR//pju7XCUUuqINJcIYoB8EYkBRgG7jDGPA+8Dnd347EzAddxtkr2vTjRWspkvIhnAScD09tRg3KNjBM9edhwAi7f5fUcqpVQ71VwiyMD6tf6Rfd5Me39P3Os+uhxIFpE+IhICXAlMrztojCkyxsQbY3obY3oDS4ALjDGph30VXnTO0K6c3LcTheXV9H7gez5YnOHtkJRS6rA0lwgeBQYCU7Cmon7B3n8l1k27WcaYGuAO4AdgIzDNGLNBRJ4UkQuOKuo2pmNUCBv3WoPL3v0lw7vBKKXUYTpk91FjzOf2esV9gY3GmBIRCQKuBrLd+XBjzAysRmbXfY8d4tzx7gbd1nRyWcQ+PDjQi5EopdTha3YcgTEmH5dqIPtX/hpPB9XedIoMdW5HhGgiUEq1L+6MLFYt6Bh1sESgnUiVUu2NJoJjwLVqaH9pi4OulVKqTXFn0jnVgpRecZx3XCJF5dWs3l3o7XCUUuqwtFgiEJGBIvKmiMwWkR/tx9zWCK696NwhjH9dPZoT+3TkQEUNFdW13g5JKaXc5k6J4GusbqSutCq8CV1jwgFrLePe8ZFejkYppdzjThtBR+AfWJPOJdgPd0YW+52kOCsR/G3GRiprtFSglGof3EkEHwD9gSiskkDdQzVQlwhmpe1j2vLdLZytlFJtgztVQ/dj3fjPc9ln3HyvX+naIcy5Xa7tBEqpdsKdm/kCtATglqDAgwWsfcWVXoxEKaXc584KZeNbIQ6f8d2dY7n0tUWk55Swbk8Rw5NivB2SUko1q8VEYK9SdiUwHKir+zDGmBbXEfBHw7rHcGLfTvy0JZeftuTyze2nMqJHrLfDUkqpQ3KnauhVrIVpDCD2PoMbC8r4q9MHJLA+s4j9pVV8uGSnJgKlVJvmTq+hi4Gp9vbdwDyslcrUIdw0tg8rH53IOUO7sGJngbfDUUqpZrmTCOKAhfb2XuAL4GaPReRDhiTGkJFfSmlljbdDUUqpQ3InEWRjVSFlA29hLVCjk9W5YUi3DhgDm7IPeDsUpZQ6JHdu6I8A27DaBCqAIuAeD8bkMwYnRgOQZq9eppRSbZE73Uc/AhCRWKCXMUY7yLupe2w4HcKCWLAll6zCcu49awAhQVqYUkq1Le7MPtpbRJZjrVt8moj8JCJPej609k9EGJzYgdlp+3ht/jZmp+3zdkhKKdWIOz9P/wN0x+o66sAaaXylJ4PyJcO7HxxQdvvUlczfnOPFaJRSqjF3EsEpwL9cXm8DkjwTju+ZMKj+RK0Lt+Z5KRKllGqaO4kgDxhmb3fGKg1keSwiHzOmT0cAzrATQlNL
WS7alsc5/1igC9oopbzCnZHFbwJ/tbc/tp8f8Ew4vic4MIBNT00iJDCAa99Zyo680kbnPPltGpv3HSA9p4Rh3XVuIqVU63Kn19DfRSQLmGLv+s4Y84Fnw/ItYcGBAPTqFMn3a/dS6zAEBojzeHSY9Z+huLzaK/EppfybW30ZjTHvG2N+ZT80CRyhccnxFJVX898Ve+rtjwq1EkHa3mIdhayUanWHTAQiUtvMQ+9WR+DsIV1J6RXHH/+7lmvfXsrCrbkUlVcTFRYMwF++38jkVxa28ClKKXVsNVc1JFizjGYBha0SjY8LCBDe/E0Kt368goVb81i4NY/Lj6/fAWtnfhlVNQ4deKaUajXN3W3eBUqBeGAdcJ8xZnjdo1Wi80FxkSG8dd0JTBiYAMCOvMaT0uUcqPBGaEopP3XIRGCMuQlIBG4DegAzRSRDRCa1VnC+Kio0iHdvGMOlo5PYU1BOSWUNgQHCnyYNAiC7SBOBUqr1NFv/YIwpBbYDO4AqrNJBdCvE5Re6x4Wz70AFRWXVjB+QwJmDrbEGWZoIlFKtqLnG4odFZCvwI9AfuBNINMZ83lrB+bqk2HCMgfTcEiJDg0iMsVYCzS4q93JkSil/0lxj8VNYjcXbsUYXXwBcYC1hjDHGXOj58Hxbj44RANQ6DJGhgUSHBdMhLIhd+8u8HJlSyp+0NKBMgH72w5XxTDj+ZWSPWEKCAqiqcRAZYv2n6JMQRUaeJgKlVOtpLhH0abUo/FR4SCAn9unIwq15ztHHfeMjWbZjv5cjU0r5k+Z6De1s7tGaQfqyR88bwmnJ8YwbYHUn7dUpgszCcu6bttq7gSml/IZHRy2JyCQR2Swi6SLSaKI6EblPRNJEZK2IzBWRXp6Mpy0a0CWaD2860TlL6fiBVs+hL1dm8v3avVTVOLwZnlLKD3gsEYhIIPAqcC4wBLhKRIY0OG0VkGKMOQ74AnjWU/G0FyN7xPLSFSMBayGbp75L825ASimf58kSwRgg3Riz3RhTBXwK1OtpZIyZZ4ypaxldgi54A8DY5Hi6x4YTHhzIJ8t2sShdF7NRSnmOJxNBd2C3y+s99r5DuQn4X1MHRORmEUkVkdTc3NxjGGLbFB8Vyi8PnMGc+08nJjyYP3251tshKaV8WJuY2UxErgFSgOeaOm6MecMYk2KMSUlISGjd4Lyoe2w4t0/oz+795WQVNj/IbMu+A1zx+mLmbtxHZY2udKaUcp8nE0Em1hxFdZLsffWIyFnAw8AFxphKD8bTLp3Y12pEnrFub7PnfZ66m6U79nPT+6lMS93T7LlKKeXKk4lgOZAsIn1EJARrrePprieIyCjgdawkkOPBWNqtIYkdOC05nmd/2MzGvcWHPK/SpXdRTnEFbyzYpg3NSim3eCwRGGNqgDuAH4CNwDRjzAYReVJELrBPew6IAj4XkdUiMv0QH+e3RIR/XDGSyJBAXp2XDsDu/WXklRwsPDkcho17ixmc2IFOkSHklVTx1aosZqft81bYSql2xJ3F64+YMWYGMKPBvsdcts/y5Pf7ivioUE7tH8/KnQUYYzjt2Xl0iwlj0YNnAvDy3K0szyhgTJ+OCGFkFpazLaeEiNBAL0eulGoP2kRjsWrZqJ5xZBVVMHN9NmBNVX3Za4v4elUm87dYPakmD+tK15gwlmzPp6rWQXF5NQ6HTgullGqeJoJ24oTecQDc+vFK577UnQXcN201O/NLOWtwZ64/tQ9dOoQ6RyM7DJRU1VDrMEx6aQH/a6HBWSnlnzQRtBNDu8U4t4d17wBAYkwYDgOFZdX07BgJQFJcRL33FZdXs6+4gk3ZB7jns9WtFq9Sqv3QRNBOBAYIpyXHExIUwOXHW71y6+YnAujZMRyAE3p3rPe+ovJq5xiEyhoH363NorpW5y9SSh2kiaAdefu6E1j92ETOHNyZ2Ihgbjm9Hym9rCqjk/p1AmBEj5h67ykqrybTZTDaHVNXMWuD9iZSSh3k0V5D6tgKCQoghAAi
QoJY/djZALx1XQq1DkOnqFAAQoMCeef6FEoqa7nrk1UUl1eTnlNS73PW7ilkynGJrR6/Uqpt0kTQzsVGhDTad8agLuwpsObym7Mxhy9W1B9pvD6rqFViU0q1D1o15KMSokPp2iHMmQT+b1xfHpkymKvG9GB9ZjHGHLpbaa3DsLuJdZPTsoq57LVFFJVXeyxupVTr00Tgo0KDAvni1pNJ6RXHPWcl8+Dkwfz2tL4M7x5LUXk1O/Mb3+hLK2t495cd3PnJSk57dh6FZVX1jt/z2SpSdxawcldBa12GUqoVaCLwYUlxEXxx6yncc9YA5766rqfjn5/Pgi31p/R+bf42nvg2jRnrrEFrm7MPOI8ZY9iyz2pr2FdUccxiLCyranYOJaWU52ki8DMDu0Y7t5/6Ls1ZRbS/tIp3f9lBWHAA/TtHAbAlp4RnZm5i8ssLeX7WZuf7djVRbXSkXpi1hXNfXsgbC7Yds89USh0eTQR+JjQokO/vGsvdZyazNaeEQY/OpKismsv/s4iy6lqm3zGW2feOIzIkkG9XZ/Ha/G2k7S3m1XkHb9Q77UTg2s6QWVjOuj31G6H/PT+d1+Y3f4PflmuVMp77YXO9ifTauubaWJRqbzQR+KGh3WK4dLS1KmhljYPffrCc7XmlvHr1aAZ0iUZEOH1gAssy9gPQtUMYAJEhgUwYmMCKjAIqqmu5fepKrntnGQATX/yJ8//1s/MGWVXj4NmZm3lm5qZmY9lfWkVcRDDVtYbVuwo9dMXH1vQ1WQx8ZGaTDepKtUeaCPxUz04RfHnbKQAszyjgtvH9mDz84NiCf1wx0rl9/9lWG0NSXAQ3je1LdnEFv35rKTPWZfPTllwKSqsoq7JWRcuy2w8Wb893vv/bNVmHjGNfcQUTBnYmQKzxDe3Bjxv3UVXr4J8/bvV2KEodE5oI/NiIpFgAQgID+P3ZA+sdCw0K5Id7xvH2dSkMTrQamKccl8jY5HguHNmNFTsP9hwa9dRs5/aGTKt66NUf05377vxkFbkHGlf7rNldSEFZNb3jIxnQJZq19nurahzc/emqeo3VbUl1rVXq2ZZb6uVIlDo2dECZHwsMEL67cyyJMWGISKPjA7tGOxuXv7tzLEPshPDir0Zyw6l96BEXzhVvLKk3cvndXzLoEx/Jsoz93DyuL28s2A7ArLRsBnXtQN/4SDZkFVNSWc0tH1kzqcZFBDO8ewxzN+VgjCE9p4RvVmfxzeosVj82sclBc0eqpLKGqNCj+2efXWyVeuoG7SnV3mki8HPDuse0fFKD8wIDhJE9YgGYc9/pVNbU8tXKTIorqvnbjE38/vM1iMCNp/bh3rMGMPmVhbwydyv7iptuDD4uKRZE+HzFHvYUlPNLep7z2B1TV/HRb090rqsQENA4YaXnHGDxtnyuOalXkwmtzjerM7n709W8c30KZwzqwsKtuczdmMPDUwYTHOh+4Tjbrv7aV1xJZU0toUG6AJBq37RqSB210KBArhzTk4tHWQ3Qa/YU8esTe9I1JozwkEDOGdq1XhL4zcm9ABg3IIGtfz2XET1iGZFkJZrbp67krzM2Os/9OT2PD5fs5JyXFnD7VKsE4XCYer12HvpyPY9+s4Gxz8xjf2n9QXCunp1pdYG98b1UnvthE9e+vYz3FmU4Sy3Nmbcph1s/WkGtw7CvuMLZgJ5VWMHO/FL2FpW38AlKtV1aIlDHTEJ0KLERwRSWVfP4+UOd+3+VksSibXmc3LcT15zUix4dIzi5bydSend0/hIf1i2G7rHhrHXpgrrs4TO58F+/8OjX6wHYmlNCSWUNd3+yik3ZB0iMCeOkvp2cvZsyC8tZtmM/k4Z1bRTbrvyyerOw1nWHTYwJY+rSXdw2vl+zpYkb3lsOwNRlu6hxGCYM6swny3axeFs+Hy7ZSYewID77v5OP9E8HwPrMImZtyObeiQOajeVwvDh7C+k5B3jqwmHOiQmVakhLBOqYmnXvOFY9OpEgl6qWvglRTL9j
LA9OHkyPjtbCOecOTyQh+uCNKSBAuOvM/s5ptQE6R4fxwLmD6n3+C7M2M3dTDtnFFaTuLOBf86xG6devPR4R2LKvcQNzTa2D8c/Pa7S/f+co/nDOQDILy5mV1vzU3IF2ldSjX6+nW0wYj543GICHvlrHxr3FpO4soKjMvTmYmhqDsCm7mAtf/YVXfkxn8bb8Jt51+DZlF/PK3K3MWJfN626UetThq651+MRysJoI1DHVOTqMuMgja9y94oSefHHrKXSODiU6zCqs1i2+0zEyhOBA4d1fMgD4969H88+rRpHcOYof7z+dc4Z2pWfHCDZnH+C2j1cw5LGZPPHtBorKq3ll7lYcBi4a2Y3TkuMBuP6U3rx/4xjOO64bgxM78PBX68hvYkDbWwu3s3JXAYEuv9A/ufkkIkKCeOy8Ic59tQ7Dgq25jd7f0M78UoY/PouHvlrHr15fTE5xBUXl1Ux6aSG19g3ls9TdzvPX7Sk64iVGF6VbCWV0z1impe5uNHfUoZRX1TJj3d56CcvhMO1mssGswnJ6P/A98zblePR7jDEkP/w/HvpqnUe/pzVo1ZBqcxb8cQJ196DEmHA+/u2JDOsew4+b9vH49DSGd4/htOR4IkKCOH9EN+f7RiTF8t3aLOp+oL37S4YzcQA8NGUwESFB5BRX0Dchyrn/pStGMuWVhbyxcDsPnjvYuX/3/jL+8v3B9orTkuP54zmD6NXJWhb0xrF9eO2nbc6usfM25XD+iG4YY6hxmCYboOdvzqWksoapS3cB8NT3G53tI3UWb8unutZBzoFKzv/XzwBcfWJPosOC+NM5gwgIEIrKqnniuw0M6BLN/43r22RVUmZhOWHBAfzlouGc/6+f+f3na3hkyhB6x0c2+/e/85OVzNmYw3d3jnV2Enjtp20898NmVj46kY5HmOhbS914lLd/3sGEQZ0bHd+9v4yIkEBnVZnDYXjq+zTGJSdwcr9OhAW71/hf1+716fLdPH3pcUcc76JteSREhZLcJZpahyFAOGZVg+7SRKDanIb/I57a3/oVf/GoJGeDdFP+fP4QptuD1y4Z1Z2U3h1555cddIwMoXN0KJ2jrQbeKJckAFY32fEDE/h6VSaFpdVsyi7m7KFdmb66/kC4f141qlFX1icuGMp901YzLjmBmRuy+dX2fOZtyuH1BdtJ/+u5AOSXVuEwhspqB2t2FwLw14uHsXVfCe8tyuDbNVl0jg4l50Alpw9I4KctuSQ//L9631OXOEb1iGXSsERenruVL1dmAtA5OpRLRifx7Zos/rd+L8O6x3DLuH5kFZbTPTacId06cMmo7ny+Yg8LtuSx7OEz611HRXUt901bzZjeHbn25N7M2Wj9kp63KYfd+8s4a0gX56DA537YxN8uHs7u/eWEBgfQxW40N8Zw60criYsM4fELhrA+s5jO0aHOqsA66TklOIxhze5C0nNKyC+t4vYJ/enTQnI6HHW9usqqapo8ftqzVjXhyB6xfH7LyXy3Nsv5oyEmPJj3bjiBrMIKBnaNon9nq/t0da2DrMJyenWKJPdAJR8sznBeO1g381P6xbcY25sLtlNZU8sdZyQDVhK6+s2lAGQ8PYWb3l/OluwDfH37qXR2+XxP00SgfIZrY+it4/uR3CWaq0/s6dZ7rzihJ3M2pjqrZdbsKSIiJJBrTupJZEgQSXHhTY5nmDw8kcnDE8kuquCat5dy43vLnaOs1+wp5IsVe/hkmfWZUaFBRIYGcu6wrvz6xF6kZRXz3qIMOkaG8JeLhjFxSBcKy6qZ8MJ863lgAqWVtSzL2M9lxycxO20fHy/dxbrMIj5bvovxAxMoKKvmvmlreGXuVjLsqcVnrMtmXHICmYXldI+zbsT3nT2A7XmlrNhZwK/fWsq3d4x1dsV9bf42ZqzLZsa6bBJjw53X9sLsLQDER4WQV2JVK32ybDcn9e3EvZ+txmHggxvHMG5AAj+n5zFzgzVr7Z6CMhZuzaNXpwh++sMEAGZt
yGZ9ltVm0VBhWRVPXjiMsqpa54SHrvJLKokICSI8pP4PhKLyal7/aRu3ju9HdFgwAAu25PL4t2kA5DQxiLGiuta5vXp3If9dsceZ+Oo+8+J/LwKgS4dQljx4JiLC+4sy+Mv3G/n69lNZuCWXf7oMmAS4+s2lZDw9pdH3Adzy4QqS4sL5ZNkuSu1/G3WJYLNLm1ZFdS3zN1vVi1OX7ao3a7CnSXubPCslJcWkpqZ6OwzVRv24aR9v/7yD928YU6/BuiXGGK58YwlhwYHWL/LOUcy+7/TD+u6cAxX86j+LnTfks4d0abIR+tWrRzuXCs05UEFCVGi9qoDsogpKKmucN8WSyhoiQwK569PV9abr+NvFw5k4pAv3TVvNwq159b7jkSmDeW3+Ns4e2oW/X3Kw2uK1+dt4ZuYmjkuK4f0bxhAWHMjJT8+lR1wE6zKbX7nu1P6dSM8paTQe5JLR3dm9v4ztuaVMHp7Ih0t2Oo/17BhBREggmxqMEh/VM5ZVTcwt9fuzB9A9LtxZ8quudTDmr3Morqhh2v+dzOiesc6/1d9mbOSNBdu556xk7jlrAJuyi5n00sJ6nzf/9+PrVYVtzj7AOS8tcL6eOKQLK3YWHLLb8ex7x5HcJZoH/ruWT5fvbvKcOjv+PrlRlU5OcQVj/ja30bkTBiZw/ohuLNqW32gFQYC+CZF8ddup3PfZau4/eyBDunVo9rvdISIrjDEpTR7TRKCUxeEwBAQIS7fnM7Br9BGNaK6ormVDVjHfrsnivUUZzv2uv6o3PTXJ7XpoV+k5JVzx+mKGJ8Vw8ajunH9cNwICBGMMt328kp35Zbx9fQrXvLXUOf3F05cM58oxB0tFG/cWc+7LCxt99he3nMwfv1jL9jzrfded3Iv3F1s39AtHduOb1VlcOLIbpw9I4L5pa0iIDuWRKYO5+9PVzs+4Y0J/7j97APO35JKWVcxzP2xu9D11xg9MYMLAzvx5+oYmj08a2pWbTuuDAJf9Z7Fz/8tXjuTCkd0BGPvMj+wpsLoER4cGUWsMZVW1iMDdZybz0pytXHZ8Es9ddhw/bcklIiSI/JJKbv14JW9fl8LM9dl8bt+E7zqjP3M25pBmr40xqmcsq3cXctv4fkSEBDW6lnOGdsEYuH1Cfx7/dgOrdhWy6tGJzo4Sf/h8DYXl1VTXOpy/8g/l/BHdWJGx3zlP191nJvOyS8mpT3wkM+857agHLjaXCLRqSClbXVXJiX07HfFnhAUHcnyvOEb3jGXikC4EBgiDu3YgNDiA9xZlcNbgzkeUBMDq7rr4wTMJDpR6vzxFhNeuOd75+sFzB/O7D1MZl5zA5Sk96n1GcoOql9CgAM4d1pWU3h358Lcnsm5PEaWVNZw9tAvFFTXcPqE//RIimTCwM+MGJNAxMoSOkSEM7Rbj7Nl11uAunD20C1OGJyIiTBjYmfEDEugbH8mtH6+kb3wkXTqE1ZuI8M4zkjm+VxyXjO7OC7O2OJNm/85RJMaEMXNDtrOqCay2mD9P38Ddn67m89Q9VNbUsqegnH4JkWzLLeVApdUe8Oh5Q7jyhB5EhgaRX1LFh0t2sig9z3mTBQgKEE7q24mSyhpnIrhkdBL3ThzAx0t3MWFQZxKiQrlj6ko+WLyTAxU1zvfVOAx3TOjP78b1JSbcqo76v3H9uOWjFewpKMcAeSWVzs8FawqVgkN0Le4UGcJzlx1HZY2D/JJKAkSICQ+ulwh25JXyl+828tRFw5r8jGNBSwRK+aCismo6hAc12fvk89Td7C4oJzEmjCvsRNHU1B3u2F9aRYewoENWw+3KLyMxNozgwAD+8Pka5m3OZe59pxMTEVzvvMKyKr5YsYdzhnalR8cIisqquW/aaubaXUB3/H0yT/9vE68v2E5IYABVtQ4A/nPNaOecVXXn1V1zrcMw5ZWFzmqp8QMTmL85l34Jkcy9fzyllTUM/fMPAE3W7+/IK+WqN5Y455bqlxDJl7ee2ij29ZlFnPfPnzlj
UGfWZRY5e5FdOjqJ2yb0o6K6limv/My4AQlkFpTx31tPYfqaLB77ZgN3n5nMvRMbtwVszj5Ap6gQYsODeeTr9Xy1KpPv7xrrbLw+Elo1pJRqE4wxh9U1sqK6loKyKhJjwqmsqSWvpIpOkSHc/ekqftiwj7Qnz+FARQ0fLt5J15gwrjmpV73355dUknOgkqAAoUfHCO78ZBW/O62vc3zKp8t2ERUWxHnHdWvq69lTUMYLs7bw1apMUnrF8cWtpzQ6p6yqhrNe+KleqSMxJoy5959ORIhValqfWcSQxA7OhJtTXMHf/7eJx84b0uK4m937y7j4379QWlnL05cOd1aNHS5NBEopn1LrMOSXVjq7BHvatOW7GZscTzeXXlWuqmsd3PbxSman7eOvFw/j6jE9j+lYgKzCcp74dgN3nZnM0G7uTRTZkCYCpZTysILSKl75cSv3nz3wqKc69wRtLFZKKQ+Liwzhzy6TLbYnOteQUkr5OU0ESinl5zyaCERkkohsFpF0EXmgieOhIvKZfXypiPT2ZDxKKaUa81giEJFA4FXgXGAIcJWIDGlw2k1AgTGmP/AP4BlPxaOUUqppniwRjAHSjTHbjTFVwKfAhQ3OuRB4397+AjhTWnv+VaWU8nOeTATdAddZmvbY+5o8xxhTAxQBRz6+Xyml1GFrF43FInKziKSKSGpubsurQCmllHKfJxNBJuA641WSva/Jc0QkCIgBGi3Yaox5wxiTYoxJSUhI8FC4Sinlnzw5oGw5kCwifbBu+FcCVzc4ZzpwHbAYuAz40bQw1HnFihV5IrKzuXOaEQ/ktXiWb9Fr9g96zf7haK6516EOeCwRGGNqROQO4AcgEHjHGLNBRJ4EUo0x04G3gQ9FJB3Yj5UsWvrcIy4SiEjqoYZY+yq9Zv+g1+wfPHXNHp1iwhgzA5jRYN9jLtsVwOWejEEppVTz2kVjsVJKKc/xt0TwhrcD8AK9Zv+g1+wfPHLN7W4aaqWUUseWv5UIlFJKNaCJQCml/JxfJIKWZkFtr0TkHRHJEZH1Lvs6ishsEdlqP8fZ+0VEXrH/BmtFZLT3Ij9yItJDROaJSJqIbBCRu+39PnvdIhImIstEZI19zU/Y+/vYs/am27P4htj7fWZWXxEJFJFVIvKd/dqnr1lEMkRknYisFpFUe5/H/237fCJwcxbU9uo9YFKDfQ8Ac40xycBc+zVY159sP24GXmulGI+1GuB+Y8wQ4CTgdvu/py9fdyVwhjFmBDASmCQiJ2HN1vsPe/beAqzZfMG3ZvW9G9jo8tofrnmCMWaky3gBz//bNsb49AM4GfjB5fWDwIPejusYXl9vYL3L681Aor2dCGy2t18HrmrqvPb8AL4BJvrLdQMRwErgRKwRpkH2fue/c6xBnCfb20H2eeLt2I/gWpPsG98ZwHeA+ME1ZwDxDfZ5/N+2z5cIcG8WVF/SxRiz197OBrrY2z73d7CL/6OApfj4ddtVJKuBHGA2sA0oNNasvVD/unxlVt+XgD8CDvt1J3z/mg0wS0RWiMjN9j6P/9vWxet9mDHGiIhP9g8WkSjgv8A9xphi12UsfPG6jTG1wEgRiQW+AgZ5NyLPEpHzgBxjzAoRGe/lcFrTWGNMpoh0BmaLyCbXg576t+0PJQJ3ZkH1JftEJBHAfs6x9/vM30FEgrGSwMfGmC/t3T5/3QDGmEJgHla1SKw9ay/Uvy63ZvVt404FLhCRDKxFrc4AXsa3rxljTKb9nIOV8MfQCv+2/SEROGdBtXsYXIk166mvqpvRFfv5G5f9v7F7GpwEFLkUN9sNsX76vw1sNMa86HLIZ69bRBLskgAiEo7VJrIRKyFcZp/W8Jrr/hZuzerb1hhjHjTGJBljemP9P/ujMebX+PA1i0ikiETXbQNnA+tpjX/b3m4caaUGmMnAFqx61Ye9Hc8xvK5PgL1ANVb94E1Y9aJzga3AHKCjfa5g9Z7aBqwDUrwd/xFe81isetS1wGr7MdmXrxs4
DlhlX/N64DF7f19gGZAOfA6E2vvD7Nfp9vG+3r6Go7z+8cB3vn7N9rWtsR8b6u5VrfFvW6eYUEopP+cPVUNKKaWaoYlAKaX8nCYCpZTyc5oIlFLKz2kiUEopP6eJQPk9EektIqbBo9AD3/O4/dmXtXy2Uq1Hp5hQ6qBVwLP2dpU3A1GqNWmJQKmDcrEG7MwB5orI9fYv+I/s+eHzROT3dSeLyO/sOeJL7fUCxtr7Q0Tk7yKyU0TKRWRBg++ZICKbRCRXRC6333OqPad8hb3/k9a6aKU0ESh10NlYySCXg8P4ASZgzfWeDTwnIiNE5AyshcRzgfuAnsB0EemENV/8A1ijQ+/Amjba1Zn258UAT9v7/og1svR24EmsaZSVahVaNaTUQUuBR+ztAmC4vf2OMeZ1EakB3gJOx7rxA/zZGDNbRHoCD2EtlnM+1jQYVxhjDjTxPS8aY94QkVuxFhUBa/qA87Cmy1iJNXWAUq1CSwRKHZRnjJljP1a47JcGz65Mg2d37Lefazj4/+CfgEuwEsJNQGrdRHNKeZqWCJQ6qJuIXOnyOth+vkFEdgF32a9/wpoI7H7gCRHph71UIrAE+BZIAT4TkS+A44wx97Tw3Q9iLUm5AWuxkT5AB6DwKK9JqRZpIlDqoFFYM7rWudd+ngvcBnQF/mCMWQNgryD1R+BFIA241xiTLyJPA+HAr7Hm0V/mxnc7gDvt78jHmmF011FfkVJu0NlHlToEEbkeeBfr5v+8l8NRymO0jUAppfyclgiUUsrPaYlAKaX8nCYCpZTyc5oIlFLKz2kiUEopP6eJQCml/Nz/A787I3wF6i8DAAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train MSE: 0.0398\n",
      "Test MSE: 1.0116\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import torch\n",
    "from torch import nn\n",
    "import torch.nn.functional as F\n",
    "from torch_geometric_temporal.nn.attention import STConv\n",
    "from torch_geometric_temporal.dataset import ChickenpoxDatasetLoader\n",
    "from torch_geometric_temporal.signal import temporal_signal_split\n",
    "\n",
    "from torch_geometric.data import DataLoader\n",
    "\n",
    "### CONFIGURATION\n",
    "seed = 42        # seeds torch's default RNG (weight init, train_loader shuffling)\n",
    "lags = 10        # time steps per input window\n",
    "stride = 1       # NOTE(review): not referenced anywhere else in this cell\n",
    "epochs = 500     # passes over the training split\n",
    "batch_size = 32  # snapshots per mini-batch\n",
    "\n",
    "# Without a fixed seed every run starts from different weights and a different\n",
    "# shuffle order, so the loss curve and the printed metrics cannot be reproduced.\n",
    "torch.manual_seed(seed)\n",
    "\n",
    "### DATA\n",
    "loader = ChickenpoxDatasetLoader()\n",
    "dataset = loader.get_dataset(lags)\n",
    "\n",
    "# Graph size and connectivity, read off the first snapshot.\n",
    "sample = next(iter(dataset))\n",
    "num_nodes = sample.x.size(0)\n",
    "edge_index = sample.edge_index\n",
    "edge_weight = sample.edge_attr\n",
    "\n",
    "# 40% of the snapshots form the training split, the remainder the test split.\n",
    "train_dataset, test_dataset = temporal_signal_split(dataset, train_ratio=0.4)\n",
    "\n",
    "# Materialise the snapshots as lists so DataLoader can batch them; only the\n",
    "# training batches are shuffled.\n",
    "train_loader = DataLoader(list(train_dataset), batch_size=batch_size, shuffle=True)\n",
    "test_loader = DataLoader(list(test_dataset), batch_size=batch_size, shuffle=False)\n",
    "\n",
    "### MODEL DEFINITION\n",
    "class AttentionGCN(nn.Module):\n",
    "    \"\"\"Per-node forecaster: one STConv block followed by a small convolutional head.\n",
    "\n",
    "    The input window is zero-padded by ``kernel_size - 1`` steps on each side of the\n",
    "    time axis. Per the STConv documentation its two temporal convolutions each shorten\n",
    "    the sequence by ``kernel_size - 1`` steps, so the block returns ``lags`` time steps.\n",
    "    The head then collapses those steps into a single value per node.\n",
    "\n",
    "    Args:\n",
    "        num_nodes: number of nodes in every graph snapshot.\n",
    "        lags: number of time steps in each input window.\n",
    "        hidden_channels: hidden and output width of the STConv block.\n",
    "        kernel_size: temporal kernel size of the STConv block.\n",
    "        K: Chebyshev filter size passed to STConv.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_nodes, lags, hidden_channels=32, kernel_size=3, K=1):\n",
    "        super().__init__()\n",
    "\n",
    "        self.num_nodes = num_nodes\n",
    "        self.lags = lags\n",
    "        self.kernel_size = kernel_size\n",
    "        # Padding added to EACH side of the time axis, i.e. 2 * (kernel_size - 1) in total.\n",
    "        self.pad_size = self.kernel_size - 1\n",
    "\n",
    "        # NOTE(review): with K=1 the Chebyshev filter appears to contain only the\n",
    "        # order-0 term, so no neighbour information would be propagated - confirm\n",
    "        # against the ChebConv docs and consider K >= 2.\n",
    "        self.gnn = STConv(\n",
    "            num_nodes=self.num_nodes,\n",
    "            in_channels=1,\n",
    "            hidden_channels=hidden_channels,\n",
    "            out_channels=hidden_channels,\n",
    "            kernel_size=self.kernel_size,\n",
    "            K=K)\n",
    "\n",
    "        head_channels = 16\n",
    "        self.conv = nn.Sequential(\n",
    "            # A kernel spanning all `lags` steps reduces the time axis to length 1.\n",
    "            nn.Conv2d(\n",
    "                in_channels=hidden_channels,\n",
    "                out_channels=head_channels,\n",
    "                kernel_size=(1, lags),\n",
    "                bias=True,\n",
    "            ),\n",
    "            nn.BatchNorm2d(head_channels),\n",
    "            nn.ReLU(),\n",
    "            # 1x1 convolution mapping the head features to one output channel.\n",
    "            nn.Conv2d(\n",
    "                in_channels=head_channels,\n",
    "                out_channels=1,\n",
    "                kernel_size=(1, 1),\n",
    "                bias=True,\n",
    "            )\n",
    "        )\n",
    "\n",
    "    def _static_graph(self, window):\n",
    "        \"\"\"Return (edge_index, edge_weight) of one snapshot of a possibly batched window.\"\"\"\n",
    "        # STConv is called with the edges of a single graph, whereas a PyG batch stores\n",
    "        # the edges of all its graphs with node ids offset per graph. Edges whose source\n",
    "        # id is below num_nodes therefore belong to the first snapshot.\n",
    "        # Assumes every snapshot in the window shares the same graph - TODO confirm if\n",
    "        # this model is ever used with a dynamic-graph dataset.\n",
    "        first_graph = window.edge_index[0] < self.num_nodes\n",
    "        edge_index = window.edge_index[:, first_graph]\n",
    "        edge_weight = None if window.edge_attr is None else window.edge_attr[first_graph]\n",
    "        return edge_index, edge_weight\n",
    "\n",
    "    def forward(self, window):\n",
    "        \"\"\"Predict one value per node for every snapshot in ``window``.\n",
    "\n",
    "        Args:\n",
    "            window: batch exposing ``x``, ``edge_index`` and ``edge_attr``; ``x`` is\n",
    "                expected to hold ``lags`` values for each of batch * num_nodes nodes.\n",
    "\n",
    "        Returns:\n",
    "            1-D tensor with batch * num_nodes predictions.\n",
    "        \"\"\"\n",
    "        edge_index, edge_weight = self._static_graph(window)\n",
    "\n",
    "        # Zero-pad both ends of the time axis, then lay the data out for STConv.\n",
    "        x = F.pad(window.x, (self.pad_size, self.pad_size))\n",
    "        x = x.view(-1, self.num_nodes, self.lags + 2 * self.pad_size, 1).permute(0, 2, 1, 3)\n",
    "        # x: (batch, lags + 2 * pad_size, num_nodes, 1)\n",
    "\n",
    "        H = F.relu(self.gnn(x, edge_index, edge_weight))  # -> (batch, lags, num_nodes, hidden_channels)\n",
    "\n",
    "        H = H.permute(0, 3, 2, 1)  # -> (batch, hidden_channels, num_nodes, lags)\n",
    "        x = self.conv(H).squeeze(3)  # -> (batch, 1, num_nodes)\n",
    "\n",
    "        return x.flatten()  # -> (batch * num_nodes,)\n",
    "    \n",
    "# Network plus an Adam optimiser over all of its parameters.\n",
    "model = AttentionGCN(lags=lags, num_nodes=num_nodes)\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.01)\n",
    "\n",
    "### TRAIN\n",
    "model.train()\n",
    "\n",
    "loss_history = []  # one entry per epoch: mean of that epoch's mini-batch losses\n",
    "for _ in tqdm(range(epochs)):\n",
    "    batch_losses = []\n",
    "    for window in train_loader:\n",
    "        optimizer.zero_grad()\n",
    "        y_pred = model(window)\n",
    "\n",
    "        # Predictions must line up element-wise with the targets.\n",
    "        assert y_pred.shape == window.y.shape\n",
    "        residual = y_pred - window.y\n",
    "        loss = residual.pow(2).mean()\n",
    "\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        batch_losses.append(loss.item())\n",
    "\n",
    "    # `total_loss` is printed as the train MSE once training has finished.\n",
    "    total_loss = sum(batch_losses) / len(batch_losses)\n",
    "    loss_history.append(total_loss)\n",
    "\n",
    "### TEST\n",
    "model.eval()\n",
    "\n",
    "# Accumulate the summed squared error and the number of predicted values so the\n",
    "# result is the MSE over the whole test split. Averaging per-batch means instead\n",
    "# would give the (usually smaller) last batch too much weight.\n",
    "squared_error_sum = 0.0\n",
    "n_values = 0\n",
    "with torch.no_grad():\n",
    "    for window in test_loader:\n",
    "        y_pred = model(window)\n",
    "\n",
    "        assert y_pred.shape == window.y.shape\n",
    "        squared_error_sum += torch.sum((y_pred - window.y) ** 2).item()\n",
    "        n_values += window.y.numel()\n",
    "\n",
    "# Plain Python float; NaN when the test split is empty instead of a NameError.\n",
    "loss = squared_error_sum / n_values if n_values else float('nan')\n",
    "\n",
    "### RESULTS PLOT\n",
    "# Training curve: one point per epoch, with epochs numbered from 1.\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(np.arange(epochs) + 1, loss_history)\n",
    "\n",
    "# Title and both axis captions share the same bold styling.\n",
    "bold = {'fontweight': 'bold'}\n",
    "ax.set_title('Loss over time', **bold)\n",
    "ax.set_xlabel('Epochs', **bold)\n",
    "ax.set_ylabel('Mean Squared Error', **bold)\n",
    "plt.show()\n",
    "\n",
    "# Last-epoch training loss next to the loss on the test split.\n",
    "print(\"Train MSE: {:.4f}\".format(total_loss))\n",
    "print(\"Test MSE: {:.4f}\".format(loss))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d25258e0",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
